// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include "base_reference_test.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/rdft.hpp"

using namespace reference_tests;
using namespace ov;

namespace {
// Bundle of everything one RDFT reference case needs: input/expected shapes
// and element types, the input and expected data wrapped as tensors, and the
// constant nodes passed to the RDFT operation.
struct RDFTParams {
    // input_value / expected_value are turned into tensors with CreateTensor
    // using input_type / expected_type respectively.
    // signal may be null; the test fixture checks it to decide which RDFT
    // constructor form to use.
    template <class T>
    RDFTParams(const Shape& input_shape,
               const Shape& expected_shape,
               const element::Type_t& input_type,
               const element::Type_t& expected_type,
               const std::vector<T>& input_value,
               const std::vector<T>& expected_value,
               const std::shared_ptr<op::v0::Constant>& axes,
               const std::shared_ptr<op::v0::Constant>& signal)
        : m_input_shape(input_shape),
          m_expected_shape(expected_shape),
          m_input_type(input_type),
          m_expected_type(expected_type),
          m_input_value(CreateTensor(input_type, input_value)),
          m_expected_value(CreateTensor(expected_type, expected_value)),
          m_axes(axes),
          m_signal(signal) {}

    // Shape of the model input parameter.
    Shape m_input_shape;
    // Shape reported for the expected output (used in the test name).
    Shape m_expected_shape;
    // Element type of the model input parameter and of m_input_value.
    element::Type_t m_input_type;
    // Element type of m_expected_value.
    element::Type_t m_expected_type;
    // Input data fed to the model.
    ov::Tensor m_input_value;
    // Reference output the model result is compared against.
    ov::Tensor m_expected_value;
    // Second argument of the RDFT constructor.
    std::shared_ptr<op::v0::Constant> m_axes;
    // Optional third argument of the RDFT constructor; null when not used.
    std::shared_ptr<op::v0::Constant> m_signal;
};

// Value-parameterized fixture for op::v9::RDFT reference checks. Each
// RDFTParams instance yields a single-operation model plus the input tensor
// and the expected output tensor.
class ReferenceRDFTLayerTest : public testing::TestWithParam<RDFTParams>, public CommonReferenceTest {
public:
    // Builds the model for the current parameter set and stores the input and
    // reference tensors (function / inputData / refOutData are not declared in
    // this class; presumably inherited from CommonReferenceTest).
    void SetUp() override {
        const auto& params = GetParam();
        // A null m_signal selects the two-argument RDFT form (data, axes);
        // otherwise m_signal is passed as the third constructor argument.
        function = (params.m_signal != nullptr) ? CreateFunctionWithSignal(params) : CreateFunction(params);

        inputData = {params.m_input_value};
        refOutData = {params.m_expected_value};
    }

    // Produces the textual name for one parameter set from its shapes, element
    // types and axes member.
    // NOTE(review): the last label reads "transpose1" and streams the
    // shared_ptr itself; with the standard operator<< for std::shared_ptr that
    // is the pointer value, not the axes contents - confirm whether intended.
    // Kept unchanged so existing test names stay stable.
    static std::string getTestCaseName(const testing::TestParamInfo<RDFTParams>& obj) {
        const auto& param = obj.param;
        std::ostringstream result;

        result << "input_shape1=" << param.m_input_shape << "; ";
        result << "output_shape=" << param.m_expected_shape << "; ";
        result << "input_type1=" << param.m_input_type << "; ";
        result << "output_type=" << param.m_expected_type << "; ";
        result << "transpose1=" << param.m_axes;

        return result.str();
    }

private:
    // Model with one Parameter feeding RDFT(data, axes).
    static std::shared_ptr<Model> CreateFunction(const RDFTParams& p) {
        const auto in = std::make_shared<op::v0::Parameter>(p.m_input_type, p.m_input_shape);
        const auto rdft = std::make_shared<op::v9::RDFT>(in, p.m_axes);

        return std::make_shared<ov::Model>(rdft, ParameterVector{in});
    }

    // Model with one Parameter feeding RDFT(data, axes, signal).
    static std::shared_ptr<Model> CreateFunctionWithSignal(const RDFTParams& p) {
        const auto in = std::make_shared<op::v0::Parameter>(p.m_input_type, p.m_input_shape);
        const auto rdft = std::make_shared<op::v9::RDFT>(in, p.m_axes, p.m_signal);

        return std::make_shared<ov::Model>(rdft, ParameterVector{in});
    }
};

// Runs once per RDFTParams instance. SetUp() has already prepared the model,
// input and reference data; Exec() is not defined in this file (presumably
// provided by CommonReferenceTest) and performs the actual run and comparison.
TEST_P(ReferenceRDFTLayerTest, CompareWithHardcodedRefs) {
    Exec();
}

static const std::vector<float> input_data = {
    0.10606491,  0.7454715,   0.57231355,  0.4582412,   0.3847059,   0.27398932, 0.66796243, 0.395475,
    0.2815729,   0.7799197,   0.59909415,  0.12294636,  0.38957402,  0.97498834, 0.46759892, 0.14017141,
    0.04206858,  0.7279963,   0.61560553,  0.9027321,   0.6226334,   0.2601217,  0.5555177,  0.40498647,
    0.14175586,  0.57774633,  0.52652127,  0.9385691,   0.9588788,   0.9844318,  0.23095612, 0.09707925,
    0.24574867,  0.6907577,   0.1974319,   0.8295272,   0.34612727,  0.51401484, 0.66115797, 0.9336245,
    0.06690067,  0.7468897,   0.39028263,  0.53575844,  0.060429193, 0.8913558,  0.77787375, 0.6701197,
    0.7350527,   0.6636995,   0.18176624,  0.8629976,   0.45142895,  0.6497297,  0.159372,   0.40598175,
    0.7988516,   0.7291543,   0.07090418,  0.7697132,   0.4972157,   0.7669217,  0.67975855, 0.13026066,
    0.6587437,   0.24532892,  0.24545169,  0.83795583,  0.105490535, 0.7264323,  0.94568557, 0.7216649,
    0.14389831,  0.7930531,   0.70895344,  0.9724701,   0.9775157,   0.49999878, 0.65569246, 0.26876843,
    0.63248956,  0.85201293,  0.5689624,   0.023386303, 0.5546464,   0.36860028, 0.9603114,  0.39123482,
    0.0380728,   0.89212376,  0.14387614,  0.63858676,  0.10003748,  0.8906635,  0.06681054, 0.7458642,
    0.45452347,  0.54724604,  0.6496482,   0.7818356,   0.6608355,   0.77711326, 0.24588613, 0.013456763,
    0.355845,    0.80388206,  0.027993264, 0.73677206,  0.52755004,  0.9052324,  0.54311025, 0.5367805,
    0.4131242,   0.7752338,   0.109669454, 0.13664648,  0.7828739,   0.9083969,  0.5247593,  0.7493595,
    0.19275227,  0.007190853, 0.6087981,   0.344136,    0.46909887,  0.41924855, 0.7072913,  0.19932869,
    0.5303847,   0.651384,    0.06686331,  0.9717932,   0.65702224,  0.11786682, 0.3154073,  0.88923013,
    0.5564087,   0.91047823,  0.28466642,  0.0934668,   0.88953066,  0.9919338,  0.18322521, 0.8185455,
    0.566391,    0.014207997, 0.29673064,  0.6347744,   0.6801958,   0.39601147, 0.34374171, 0.7216888,
    0.6152569,   0.76679546,  0.5860851,   0.4276813,   0.79339284,  0.13130653, 0.68764234, 0.053128112,
    0.02611321,  0.2982243,   0.7618372,   0.3331729,   0.5468192,   0.15707079, 0.28592056, 0.15286565,
    0.9368963,   0.350671,    0.4336494,   0.08934934,  0.41172776,  0.5850259,  0.70730376, 0.8598349,
    0.088788144, 0.26711187,  0.8002491,   0.19422275,  0.8312039,   0.5198718,  0.40111357, 0.98375803,
    0.77703434,  0.037818834, 0.704231,    0.689808,    0.17102319,  0.42153922, 0.7278252,  0.8030207,
    0.9101717,   0.0199644,   0.13768466,  0.55669,     0.17991355,  0.6720098,  0.7733328,  0.20881335};

static const std::vector<float> expected_rdft1d_results_1 = {
    4.6657147,   -1.1622906e-06, 0.21456887,    -0.14946258, -0.20476034,  -0.37063062,
    -0.31414136, 0.5099413,      -1.1779613,    0.07057127,  -0.64047664,  -1.0058284e-07,
    4.982774,    -1.1771917e-06, 0.6607505,     0.18829148,  -0.9772357,   1.4243596,
    0.8640026,   0.34923682,     0.33401352,    0.25859502,  -0.7548928,   8.940697e-08,
    5.9711604,   -1.4901161e-06, 0.5638976,     1.5429841,   -0.52065414,  0.24638398,
    -0.27140495, 0.5040715,      0.5360231,     0.3234269,   -0.36054826,  1.7508864e-07,
    4.7464237,   -1.2218952e-06, -0.29650804,   0.80609477,  -0.161426,    1.0022418,
    -0.50812817, 0.7967348,      0.4394225,     -0.1588624,  -1.3835809,   -7.4505806e-08,
    5.53836,     -1.7136335e-06, -0.38635445,   0.8284859,   -0.23278837,  -0.63777345,
    -0.93614054, 0.3215857,      -0.14075133,   -0.67071164, -1.4772836,   2.0861626e-07,
    5.0798974,   -1.5944242e-06, 0.056767445,   0.03468219,  -0.1497254,   -0.9672509,
    0.2603209,   0.69644475,     -0.9208536,    0.006730467, -1.7552528,   2.682209e-07,
    4.893558,    -1.6242266e-06, 0.6719861,     -0.13982919, 0.064845346,  -0.39896214,
    0.21785057,  -0.5099982,     -0.65526295,   1.4383471,   -0.52023906,  2.5331974e-07,
    6.687699,    -1.5497208e-06, -0.7423769,    0.09968524,  1.052381,     -0.21306956,
    0.5875206,   -0.3038844,     0.3991575,     -1.1895186,  0.17579001,   3.874302e-07,
    5.2818384,   -1.1026859e-06, 0.5087582,     0.106959194, 1.1816688,    -0.87592727,
    0.03740315,  0.5197907,      -1.3198637,    0.6398836,   0.22712436,   2.2351742e-08,
    5.0190897,   -1.5646219e-06, -0.087282926,  0.50819266,  -0.28002462,  0.29240948,
    -0.32303664, 0.38377762,     -0.0051696897, -0.99301195, -2.189299,    2.0861626e-07,
    5.0545654,   -1.5795231e-06, 0.9146397,     0.83839166,  0.870533,     0.17405808,
    -0.56308234, -0.7806684,     0.26397777,    0.6880482,   -1.4183462,   2.682209e-07,
    5.479953,    -1.2665987e-06, 0.49444157,    0.7534672,   -0.76784146,  -0.4507342,
    0.88815784,  0.6985409,      -0.2727425,    -0.25027415, -0.7328796,   2.682209e-07,
    4.1296124,   -5.662441e-07,  -0.46133032,   0.30635798,  -0.18225375,  0.42515472,
    -0.5484285,  0.9704039,      -0.35255045,   0.17549685,  0.8870368,    -3.1292439e-07,
    4.8632016,   -1.8924475e-06, -0.6926452,    0.025076404, -0.039108217, -1.7492937,
    -0.8120377,  -0.85315156,    -0.0022608787, 0.45002514,  -1.1024668,   3.501773e-07,
    5.4715447,   -1.4901161e-06, 1.1176248,     -0.2109062,  -0.27492502,  0.08983741,
    1.1903813,   -1.007312,      -0.20150042,   -0.83919466, -0.23939973,  4.917383e-07,
    5.1267176,   -9.983778e-07,  -0.44803134,   -0.8066604,  -0.3435102,   -0.41692197,
    -0.22457689, -0.1076939,     -0.29129186,   -1.1880502,  0.9255183,    -1.6391277e-07,
    3.8495903,   -5.5134296e-07, 0.09505272,    -0.12751618, -1.1264827,   0.5068884,
    -1.055237,   -0.19516481,    -0.34035242,   -0.15379356, 1.2655814,    -2.6077032e-07,
    4.4372616,   -9.23872e-07,   -0.72962606,   -0.23475963, -0.04278487,  1.1032158,
    -0.558924,   -0.5300043,     1.0578637,     -0.2466627,  0.44617313,   -7.8231096e-08,
    5.5374002,   -1.4156103e-06, 0.016273111,   -0.5989829,  -0.19913958,  0.013256833,
    1.8512837,   0.14526272,     -0.39700353,   -0.07573915, 0.23181,      2.9429793e-07,
    4.989425,    -1.4901161e-06, 1.0391837,     0.16554561,  -0.22647032,  -1.0689808,
    -0.84556,    -0.82779336,    0.9430445,     0.37618563,  0.4684292,    -9.685755e-08};

static const std::vector<float> expected_rdft1d_results_2 = {
    2.266797,  -8.195639e-08,  -0.37842733,  -0.41015846,  -0.48980892,  -0.10356337,
    2.5542018, -2.2351742e-08, -0.3223713,   0.671882,     0.54300576,   -0.35418037,
    1.985015,  -2.2351742e-08, -0.030243821, -0.20105253,  0.59431964,   0.07358998,
    1.4619737, -7.450581e-09,  -0.4356845,   0.35701087,   0.28208786,   -0.36424285,
    1.8002605, -1.1920929e-07, -0.43280697,  -0.56735414,  -0.30007166,  -0.541847,
    2.3052943, -1.2293458e-07, -0.39316025,  -0.5526293,   -0.30507135,  -0.6021758,
    2.7329001, -6.7055225e-08, 0.28245124,   -0.42586988,  -0.40586215,  0.4590181,
    3.3132548, -5.9604645e-08, 0.6297612,    0.3694744,    0.077824846,  -0.6248544,
    2.6314974, -2.9802322e-08, 0.58795106,   -0.60349375,  -0.3224758,   0.34408605,
    1.8399743, -9.685755e-08,  -0.43963802,  -0.079073176, -0.120658875, -1.0880115,
    2.0531366, -4.4703484e-08, 0.80112594,   -0.53726834,  -0.17560546,  -0.026561722,
    2.3779182, -9.685755e-08,  -0.21852754,  -0.19336401,  0.38734403,   -0.5954362,
    1.6219761, 7.450581e-09,   -0.43100592,  0.28373614,   0.101898566,  0.52321124,
    2.128953,  -1.4901161e-07, -0.1622684,   -0.94116735,  -0.7350497,   0.12695336,
    3.449626,  -8.940697e-08,  0.56062996,   -0.031283244, -0.06161648,  -0.8543532,
    3.033568,  -8.195639e-08,  -0.37023768,  -0.03989461,  -0.28719214,  -0.22382751,
    1.9661667, -1.4901161e-08, -0.59863573,  -0.015534669, -0.31916466,  0.55380434,
    2.227056,  -5.2154064e-08, -0.12656188,  0.6895717,    0.097157195,  0.19840825,
    3.5129817, -2.1234155e-07, 0.11158541,   0.5870459,    0.20993343,   -0.40297145,
    2.5986667, 0.0,            0.26602313,   -1.1560227,   0.2542065,    0.45556274};

static const std::vector<float> expected_rdft1d_results_3 = {
    4.665715,    -1.6093254e-06, -0.5430559,   -0.5752678,     -0.37596112,  -1.1571281,
    -0.46793216, -0.94566363,    0.6854232,    -0.3444838,     -0.674704,    0.5946392,
    -0.64047587, 1.3560057e-06,  4.9827743,    -1.7434359e-06, -0.43517,     -0.049020194,
    -1.4773891,  -1.0811031,     1.2506557,    0.5371344,      1.2869358,    -0.14998645,
    0.8555907,   0.3693859,      -0.7548918,   1.5944242e-06,  5.971161,     -1.5199184e-06,
    -1.2643411,  0.85635287,     -0.1801207,   -1.7264944,     0.6412285,    -0.4787441,
    0.82227707,  0.65098876,     0.9114491,    0.40323836,     -0.36054718,  1.2852252e-06,
    4.7464237,   -1.66893e-06,   -1.5010594,   0.2253451,      -0.87915635,  -0.4252541,
    0.4976693,   -0.6554581,     0.928985,     0.8035921,      0.6578763,    -0.15220329,
    -1.3835799,  1.0430813e-06,  5.5383606,    -1.4901161e-06, -1.619024,    -0.10987502,
    0.20661727,  -1.3774645,     -0.3057741,   -1.0960662,     0.2971667,    0.46700704,
    -0.20812088, -0.602368,      -1.4772825,   9.3877316e-07,  5.0798974,    -1.758337e-06,
    -0.7421876,  -0.61749315,    0.21938956,   -1.3415859,     -0.838238,    -0.6598083,
    1.0601404,   -0.7129184,     -0.27083004,  0.31763482,     -1.7552516,   1.4677644e-06,
    4.893558,    -1.4975667e-06, -0.06445231,  -0.55879503,    0.08908144,   -1.2869594,
    0.33623943,  -0.7704663,     -0.047739983, -1.0678453,     0.48350462,   1.5768427,
    -0.52023804, 1.1697412e-06,  6.687699,     -1.3113022e-06, -1.292419,    -1.2920969,
    1.2041754,   -0.2943018,     1.1889167,    -0.66985166,    1.1336832,    -0.13731277,
    0.008011267, -0.9506076,     0.1757915,    1.1026859e-06,  5.2818394,    -1.4305115e-06,
    -0.25987166, -0.48605326,    0.90237427,   -0.8028362,     -0.3040653,   -1.6981151,
    1.1215456,   -0.7120959,     -0.4195284,   1.3941492,      0.22712523,   8.046627e-07,
    5.01909,     -1.7881393e-06, -1.1856917,   -0.10931289,    -0.5164983,   -0.9724103,
    0.30577338,  -0.72837675,    0.89680094,   0.21036407,     -0.052024096, -0.9455472,
    -2.1892984,  1.4305115e-06,  5.054565,     -1.5050173e-06, -0.3471575,   0.40542153,
    0.36438322,  -0.9765247,     1.2703501,    -1.7359983,     -0.1160066,   -0.25323528,
    0.9753329,   0.5339062,      -1.418345,    9.834766e-07,   5.4799523,    -1.7285347e-06,
    -0.7905842,  0.093313254,    0.068526804,  -1.8504739,     -0.01845923,  0.26084417,
    1.5358877,   -0.4159652,     0.089752786,  0.089908056,    -0.7328786,   1.4007092e-06,
    4.129612,    -9.536743e-07,  -1.2393575,   -0.28046644,    -0.58673245,  -0.39608067,
    -0.12385368, -0.53435826,    0.77853805,   0.7645384,      -0.18040559,  0.6678516,
    0.88703763,  8.046627e-07,   4.8632016,    -1.0430813e-06, -1.1780663,   -1.0952923,
    1.1691413,   -1.4023741,     -0.546494,    -0.92614484,    -1.1796933,   -0.31762218,
    0.25592417,  0.0959474,      -1.1024656,   1.013279e-06,   5.471545,     -1.6987324e-06,
    0.35812324,  -0.66833705,    0.07725692,   -1.6537004,     1.6561611,    0.051166296,
    0.865453,    -1.1392289,     -0.23588535,  -0.5480979,     -0.2393986,   1.3411045e-06,
    5.126718,    -9.23872e-07,   -0.6379836,   -1.6675751,     0.013057679,  -0.9891113,
    0.20881936,  -0.30439606,    0.37222707,   0.25244698,     -0.9197892,   -0.77782196,
    0.9255192,   1.1101365e-06,  3.8495903,    -7.4505806e-07, -0.63088936,  -0.4556699,
    -1.1905057,  -1.2522144,     0.46207082,   -0.31992733,    -0.4309795,   0.74295896,
    -0.6106033,  0.18823686,     1.2655822,    7.748604e-07,   4.4372616,    -7.0780516e-07,
    -1.1016369,  -1.0079124,     -0.6083025,   -0.0011255145,  1.4406854,    -0.2912693,
    -0.26610214, 0.87299407,     0.69553405,   -0.45576566,    0.44617438,   7.4505806e-07,
    5.5374007,   -1.5944242e-06, -0.32642078,  -1.3683549,     0.079301864,  -0.83741367,
    0.67391664,  0.69433576,     1.6423957,    -1.1923066,     0.0334223,    0.37603495,
    0.23181117,  1.4156103e-06,  4.9894247,    -7.748604e-07,  0.1788401,    -0.39274544,
    0.78422666,  -2.1340246,     0.5487572,    -0.8765497,     -0.7899384,   0.5434137,
    0.91613716,  0.08274247,     0.46843058,   8.34465e-07
};

const std::vector<float> expected_rdft2d_results = {
    52.8665,     -2.9623508e-05, 1.1642078,    3.826082,    -0.22771922,  -0.49822173,
    -0.3857528,  3.2676966,      -2.5112464,   -0.27454787, -8.678656,    3.7550926e-06,
    -0.818072,   0.8330209,      3.4618711,    -0.2419473,  1.7408192,    5.744002,
    1.8477443,   2.039329,       0.3268112,    -2.7421296,  0.6809025,    1.7613728,
    -2.294264,   -0.8984407,     -0.2868184,   -3.2426705,  -0.801461,    -0.58971727,
    -1.463435,   -2.5413132,     0.116907075,  -0.5013529,  -2.8377397,   -2.8455539,
    -0.13475686, -1.3145845,     -2.2820292,   -0.199,      -0.056986623, 0.12560216,
    -0.589707,   -1.7577857,     -0.5274223,   -1.0395792,  0.53813136,   -1.7159984,
    0.22503978,  2.902198,       -1.8643543,   -1.8789856,  2.1722724,    -2.068454,
    0.59446484,  0.6067899,      1.5525781,    1.7612485,   1.1877432,    -0.48152098,
    -0.16525066, 1.5497208e-06,  1.9815066,    0.55218977,  0.80434155,   -3.575598,
    -2.1471107,  -0.57691807,    -3.004384,    3.8775828,   3.1358109,    -6.2584877e-07,
    0.22504184,  -2.9021916,     1.0378464,    0.9877456,   0.38395065,   -1.6089694,
    -0.5107449,  1.8621777,      -4.960479,    -1.8983803,  1.187743,     0.48151842,
    -0.1347583,  1.3145843,      -0.9968031,   -1.3782079,  0.9922035,    1.6614089,
    -0.83039653, -0.043888614,   1.9431384,    -1.6448143,  0.5381324,    1.7159982,
    -2.2942696,  0.8984335,      1.3057998,    -0.26607463, -3.2994738,   -1.9240448,
    1.4963659,   2.8365738,      -4.691832,    1.2995429,   -2.8377357,   2.8455553,
    -0.8180722,  -0.8330165,     -1.3755352,   0.34623986,  -3.7555497,   -0.9723124,
    -1.1528367,  -0.593254,      -0.023679793, 1.8681414,   0.6809023,    -1.7613728,
    48.939255,   -2.4735928e-05, 1.3455832,    0.11001387,  -2.3319814,   -1.3735183,
    -0.6780232,  -2.4875786,     0.40718403,   -1.0639579,  0.7314569,    -1.2665987e-07,
    0.97006464,  -0.30789328,    3.3290033,    2.7749023,   -0.7520597,   -0.98800826,
    1.3100916,   1.1514524,      1.1085359,    4.348257,    -2.839456,    2.4404035,
    0.9518837,   2.1538901,      3.8438358,    2.410589,    3.0649068,    0.95690995,
    2.2213395,   0.66509914,     -0.4409917,   -0.37408838, -0.6316552,   -1.5842111,
    -0.72352415, -2.5862057,     0.2678757,    0.610149,    2.9564474,    0.08470708,
    -2.0889034,  -8.370071,      -0.16373271,  2.0413866,   -3.3811545,   2.0487003,
    0.0316903,   -1.078939,      -2.5515578,   -0.16135174, -0.17406325,  1.2709827,
    -0.67006403, -1.6342779,     0.42163712,   2.1418998,   -0.96614444,  1.9175051,
    -0.8538456,  2.8014183e-06,  2.0189362,    0.30467552,  0.5074463,    3.7919073,
    2.427857,    0.7526233,      -2.4620402,   0.65359443,  0.7219074,    -2.3841858e-07,
    0.03169757,  1.0789458,      -2.1129081,   -1.0250417,  4.8181386,    -0.39162922,
    -1.2349386,  1.8470186,      -0.49495277,  -1.5516026,  -0.96614635,  -1.9175065,
    -0.7235237,  2.5862021,      0.677946,     2.0370173,   -0.29536027,  0.6505451,
    -2.8572361,  2.3176546,      3.4459226,    1.1869265,   -3.3811545,   -2.048697,
    0.95187366,  -2.1538982,     1.808088,     -1.1755496,  -2.7418838,   -1.6770658,
    -3.5766084,  -2.8320727,     -0.02944839,  -1.6522555,  -0.63165283,  1.5842092,
    0.9700667,   0.30789307,     0.5195943,    2.4985125,   3.6537378,    -0.5842519,
    -0.4843334,  0.78346854,     0.84766304,   1.1503224,   -2.839459,    -2.440402};

const std::vector<float> expected_rdft2d_results_2 = {
    25.904434,   -8.46386e-06,  -5.3626504,  0.3475349,   -2.7060094,   -5.767444,
    1.615847,    -2.6387978,    4.020789,    1.4271183,   1.5420923,    0.6126925,
    -4.6167765,  5.5730343e-06, -0.753784,   -0.19148755, 1.4881928,    -2.7645326,
    -0.39467168, 1.014636,      0.5598,      -1.7654291,  -0.91835654,  -2.3019042,
    -0.49356225, -0.8411435,    0.080773115, -1.2883577,  -0.5341466,   1.4913602,
    -0.30008763, -0.5831754,    1.7365295,   1.821624,    -0.08851206,  -1.622279,
    -0.27249795, -0.834725,     -0.6706438,  0.4766277,   0.62642634,   0.5483514,
    -0.5341469,  -1.4913592,    0.8286207,   0.35826343,  -1.0869694,   -1.4876881,
    -1.6723244,  -0.06565219,   0.16255295,  0.5317876,   -0.75649667,  1.2447717,
    0.6264261,   -0.5483517,    -0.7537827,  0.19148779,  0.6306459,    -0.23442982,
    0.57131517,  -1.366768,     -2.7544713,  1.3638397,   0.43463084,   -0.5446956,
    -2.9949086,  1.4802479,     0.080771565, 1.2883584,   24.998875,    -7.390976e-06,
    -3.1970425,  -1.5453612,    1.0925753,   -6.279154,   2.237704,     -2.8844912,
    1.8841789,   -1.3615136,    0.90471864,  0.8395144,   -2.6060505,   4.976988e-06,
    1.1634235,   0.42319643,    2.678257,    2.4692535,   0.34259582,   0.43598562,
    2.748452,    0.88622695,    2.2745323,   -2.8840196,  1.8120161,    -0.27884078,
    -1.5445104,  -0.7000726,    -1.0264511,  -0.7026249,  -1.071573,    1.062395,
    -0.64628685, -0.36214483,   -0.5110928,  -1.0534683,  -2.786768,    2.6113648,
    0.94799054,  0.53423727,    -0.69832724, 2.1821892,   -1.0264513,   0.70262754,
    -0.41705567, -0.17140968,   1.4991179,   2.9674625,   -0.012362838, -3.8260121,
    -1.5786235,  -0.32526863,   1.2857957,   1.7469958,   -0.6983267,   -2.1821907,
    1.1634252,   -0.42319855,   0.2716269,   0.21222934,  -0.46608746,  -1.6447732,
    1.8890494,   -1.8022469,    -0.37335354, 0.69326025,  -0.07385725,  -0.1723765,
    -1.5445105,  0.7000739};

const std::vector<float> expected_rdft3d_results = {
    101.805756,  -5.2273273e-05, 2.5097876,    3.936094,     -2.5597036,  -1.8717405,
    -1.0637736,  0.7801182,      -2.1040666,   -1.3385094,   -7.9471993,  2.026558e-06,
    0.15199316,  0.52512753,     6.7908745,    2.5329556,    0.98875976,  4.755993,
    3.157838,    3.190782,       1.4353466,    1.6061276,    -2.158554,   4.201776,
    -1.3423799,  1.2554499,      3.5570183,    -0.8320818,   2.263445,    0.36719292,
    0.7579028,   -1.8762131,     -0.32408538,  -0.87544185,  -3.4693956,  -4.429764,
    -0.85828185, -3.9007902,     -2.0141544,   0.4111499,    2.8994608,   0.21030927,
    -2.6786098,  -10.127857,     -0.6911557,   1.0018079,    -2.8430226,  0.33270124,
    0.25672907,  1.8232578,      -4.4159126,   -2.040338,    1.9982092,   -0.7974717,
    -0.07559925, -1.0274884,     1.9742157,    3.9031482,    0.22159882,  1.4359848,
    -1.0190966,  3.2186508e-06,  4.0004425,    0.8568655,    1.3117876,   0.2163087,
    0.28074512,  0.17570588,     -5.466423,    4.531178,     3.857718,    -1.2516975e-06,
    0.2567385,   -1.823246,      -1.0750613,   -0.037295938, 5.20209,     -2.0005994,
    -1.7456844,  3.7091968,      -5.45543,     -3.4499822,   0.22159535,  -1.4359887,
    -0.8582816,  3.9007854,      -0.31885874,  0.65880924,   0.6968423,   2.3119528,
    -3.6876333,  2.273767,       5.38906,      -0.45788872,  -2.8430223,  -0.33269957,
    -1.3423961,  -1.2554631,     3.1138885,    -1.4416232,   -6.0413575,  -3.6011095,
    -2.080242,   0.0045015216,   -4.7212796,   -0.3527125,   -3.4693892,  4.429763,
    0.15199506,  -0.52512354,    -0.85594195,  2.8447511,    -0.10181111, -1.5565643,
    -1.6371696,  0.19021615,     0.8239815,    3.018465,     -2.158556,   -4.2017746,
    3.9272437,   -3.9339066e-06, -0.18137527,  3.7160687,    2.1042633,   0.8752967,
    0.29226887,  5.755277,       -2.9184306,   0.78941,      -9.410112,   3.0100346e-06,
    -1.7881365,  1.140914,       0.13286811,   -3.01685,     2.4928799,   6.7320104,
    0.5376528,   0.88787735,     -0.78172505,  -7.0903873,   3.5203578,   -0.6790314,
    -3.246148,   -3.0523329,     -4.1306543,   -5.653259,    -3.866367,   -1.5466263,
    -3.6847744,  -3.2064118,     0.5578996,    -0.12726665,  -2.2060838,  -1.2613428,
    0.588767,    1.2716217,      -2.5499039,   -0.8091496,   -3.0134337,  0.0408957,
    1.4991964,   6.6122847,      -0.36368948,  -3.0809648,   3.9192853,   -3.764699,
    0.19334978,  3.9811373,      0.68720365,   -1.717634,    2.346336,    -3.3394372,
    1.2645291,   2.241068,       1.1309403,    -0.3806507,   2.1538877,   -2.3990266,
    0.6885946,   -1.4901161e-06, -0.037429705, 0.24751475,   0.2968948,   -7.367506,
    -4.574969,   -1.329541,      -0.5423446,   3.2239883,    2.4139037,   2.9802322e-07,
    0.19334424,  -3.9811373,     3.1507545,    2.0127864,    -4.4341884,  -1.2173393,
    0.72419256,  0.015158802,    -4.4655256,   -0.34677732,  2.1538897,   2.3990245,
    0.5887663,   -1.2716188,     -1.6747494,   -3.415226,    1.2875631,   1.0108626,
    2.0268395,   -2.3615427,     -1.502785,    -2.8317401,   3.919288,    3.764695,
    -3.2461433,  3.0523314,      -0.5022881,   0.9094755,    -0.55759126, -0.24697942,
    5.0729737,   5.668646,       -4.662384,    2.9517999,    -2.2060819,  1.2613468,
    -1.7881389,  -1.1409098,     -1.8951292,   -2.1522717,   -7.4092865,  -0.38806117,
    -0.6685039,  -1.3767233,     -0.8713439,   0.71781945,   3.5203605,   0.6790297};

const std::vector<float> expected_rdft3d_results_2 = {
    50.90331,     -1.4543533e-05, -8.559692,   -1.1978266,     -1.6134334,  -12.046599,
    3.8535514,    -5.5232873,     5.9049683,   0.065603495,    2.4468107,   1.4522064,
    -7.222825,    1.2278557e-05,  0.40963984,  0.231709,       4.16645,     -0.29528028,
    -0.052075505, 1.450621,       3.3082519,   -0.8792013,     1.356175,    -5.1859245,
    1.3184534,    -1.1199851,     -1.4637363,  -1.9884299,     -1.5605974,  0.7887349,
    -1.3716602,   0.47921878,     1.0902424,   1.4594792,      -0.59960556, -2.6757474,
    -3.0592656,   1.7766399,      0.27734682,  1.0108652,      -0.07190053, 2.7305403,
    -1.5605986,   -0.78873086,    0.41156515,  0.18685403,     0.4121489,   1.4797752,
    -1.6846865,   -3.8916636,     -1.4160703,  0.20651829,     0.52929974,  2.9917672,
    -0.07190076,  -2.7305427,     0.4096415,   -0.23171037,    0.9022726,   -0.022200808,
    0.10522783,   -3.0115416,     -0.8654218,  -0.4384073,     0.061277367, 0.14856634,
    -3.0687659,   1.3078697,      -1.4637384,  1.9884316,      25.904425,   -24.998884,
    -6.9080105,   3.5445771,      -8.985163,   -6.860018,      -1.2686447,  -4.8765025,
    2.6592734,    -0.45706248,    2.3816066,   -0.29202732,    -4.6167727,  2.6060565,
    -0.33058774,  -1.3549114,     3.9574459,   -5.44279,       0.041313916, 0.67204094,
    1.446027,     -4.5138807,     -3.8023772,  -4.576436,      -0.7724026,  -2.6531591,
    -0.6192993,   0.25615194,     -1.2367722,  2.5178113,      0.7623075,   0.48839718,
    1.3743844,    2.4679115,      -1.1419809,  -1.1111865,     2.3388672,   1.9520425,
    -0.13640736,  -0.47136223,    2.8086162,   1.2466785,      0.16848034,  -0.46490768,
    0.6572111,    0.7753189,      1.8804929,   -2.9868064,     -5.498336,   -0.053289652,
    -0.16271627,  2.1104114,      0.9904991,   -0.041024223,   -1.5557647,  0.14997506,
    -1.1769819,   -0.9719368,     0.8428756,   -0.5060569,     -1.0734584,  -0.9006812,
    -4.556718,    -0.5252099,     1.1278908,   -0.17134166,    -3.1672862,  1.5541049,
    0.78084624,   2.8328683,      0.90555733,  -1.3709068e-06, -2.1656086,  1.8928962,
    -3.7985847,   0.511709,       -0.62185717, 0.24569236,     2.1366088,   2.7886305,
    0.6373716,    -0.2268233,     -2.0107267,  5.662441e-07,   -1.9172084,  -0.6146841,
    -1.1900643,   -5.233785,      -0.73726743, 0.5786506,      -2.188651,   -2.6516552,
    -3.1928902,   0.58211625,     -2.305578,   -0.5623034,     1.6252834,   -0.58828497,
    0.49230486,   2.1939852,      0.7714851,   -1.6455705,     2.382816,    2.1837692,
    0.4225806,    -0.56881106,    2.514269,    -3.4460905,     -1.618634,   -0.057608932,
    1.3247533,    -1.6338379,     0.49230492,  -2.1939862,     1.2456759,   0.5296728,
    -2.5860875,   -4.45515,       -1.659962,   3.7603593,      1.7411764,   0.8570565,
    -2.0422916,   -0.50222373,    1.3247528,   1.633839,       -1.9172082,  0.6146865,
    0.35901868,   -0.44665974,    1.0374024,   0.27800465,     -4.6435204,  3.1660864,
    0.8079842,    -1.2379556,     -2.921052,   1.6526239,      1.6252828,   0.588284,
    25.90444,     24.998867,      -3.817289,   -2.8495073,     3.573144,    -4.6748676,
    4.500339,     -0.40109348,    5.382302,    3.3112957,      0.7025763,   1.5174108,
    -4.616783,    -2.6060438,     -1.1769816,  0.97193646,     -0.9810596,  -0.086276084,
    -0.83065766,  1.3572321,      -0.3264265,  0.9830234,      1.9656628,   -0.027371943,
    -0.2147214,   0.9708719,      0.7808455,   -2.8328671,     0.16847888,  0.46490908,
    -1.3624828,   -1.6547482,     2.0986745,   1.1753378,      0.9649557,   -2.1333718,
    -2.8838634,   -3.6214924,     -1.2048804,  1.4246187,      -1.5557631,  -0.14997569,
    -1.2367743,   -2.5178103,     1.0000296,   -0.05879204,    -4.0544314,  0.01142931,
    2.153687,     -0.078014135,   0.4878212,   -1.0468364,     -2.503492,   2.5305676,
    2.808617,     -1.2466786,     -0.33058444, 1.3549128,      0.41841656,  0.03719666,
    2.216088,     -1.8328552,     -0.95222485, 3.2528882,      -0.25863037, -0.91804826,
    -2.822532,    1.4063904,      -0.6193025,  -0.25615215};

// Returns a std::vector<T> holding static_cast<T> of every float in v, in the
// same order. An empty input yields an empty vector.
template <class T>
static std::vector<T> convert(const std::vector<float>& v) {
    std::vector<T> out;
    out.reserve(v.size());
    for (const float src : v) {
        out.push_back(static_cast<T>(src));
    }
    return out;
}

template <class T>
static std::vector<T> convert(const std::vector<float16>& v) {
    if (v.empty()) {
        return std::vector<T>();
    }

    size_t num_of_elems = v.size();
    std::vector<T> converted(num_of_elems);
    for (size_t i = 0; i < num_of_elems; ++i) {
        converted[i] = static_cast<T>(v[i]);
    }
    return converted;
}

template <class T>
static std::vector<T> convert(const std::vector<bfloat16>& v) {
    if (v.empty()) {
        return std::vector<T>();
    }

    size_t num_of_elems = v.size();
    std::vector<T> converted(num_of_elems);
    for (size_t i = 0; i < num_of_elems; ++i) {
        converted[i] = static_cast<T>(v[i]);
    }
    return converted;
}

template <element::Type_t ET>
std::vector<RDFTParams> generateParamsForRDFT() {
    std::vector<RDFTParams> params{
        // rdft1d_eval
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft1d_results_1,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
                   NULL),
        // rdft1d_eval_signal_size_0
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft1d_results_1,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {10})),
        // rdft1d_eval_signal_size_0_1
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft1d_results_1,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1})),
        // rdft1d_eval_1
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft1d_results_1,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1}),
                   NULL),
        // rdft1d_eval_signal_size_1
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 3, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft1d_results_2,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {5})),
        // rdft1d_eval_signal_size_1_1
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 3, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft1d_results_2,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {5})),
        // rdft1d_eval_signal_size_2
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 7, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft1d_results_3,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {12})),
        // rdft1d_eval_signal_size_2_1
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 7, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft1d_results_3,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {12})),
        // rdft2d_eval_1
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
                   NULL),
        // rdft2d_eval_1_positive_negative_axes
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, -1}),
                   NULL),
        // rdft2d_eval_1_negative_positive_axes
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, 2}),
                   NULL),
        // rdft2d_eval_1_negative_negative_axes
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, -1}),
                   NULL),
        // rdft2d_eval_1_signal_size_0_s10_10
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
        // rdft2d_eval_1_signal_size_0_s10_10_positive_negative_axes
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, -1}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
        // rdft2d_eval_1_signal_size_0_s10_10_negative_positive_axes
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, 2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
        // rdft2d_eval_1_signal_size_0_s10_10_negative_negative_axes
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, -1}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
        // rdft2d_eval_1_signal_size_0_s10_m1
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, -1})),
        // rdft2d_eval_1_signal_size_0_sm1_10
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-1, 10})),
        // rdft2d_eval_1_signal_size_0_sm1_m1
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-1, -1})),
        // rdft2d_eval_2_signal_size
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 5, 7, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results_2,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
        // rdft2d_eval_2_signal_size_positive_negative_axes
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 5, 7, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results_2,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, -1}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
        // rdft2d_eval_2_signal_size_negative_positive_axes
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 5, 7, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results_2,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, 2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
        // rdft2d_eval_2_signal_size_negative_negative_axes
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 5, 7, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft2d_results_2,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, -1}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
        // rdft3d_eval_1
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft3d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {0, 1, 2}),
                   NULL),
        // rdft3d_eval_1_negative_axes_and_signal_size
        RDFTParams(Shape{2, 10, 10},
                   Shape{2, 10, 6, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft3d_results,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-3, 1, 2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-1, 10, -1})),
        // rdft3d_eval_2
        RDFTParams(Shape{2, 10, 10},
                   Shape{4, 5, 7, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft3d_results_2,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {0, 1, 2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {4, 5, 12})),
        // rdft3d_eval_2_negative_axes
        RDFTParams(Shape{2, 10, 10},
                   Shape{4, 5, 7, 2},
                   ET,
                   ET,
                   input_data,
                   expected_rdft3d_results_2,
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-3, -2, 2}),
                   op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {4, 5, 12})),
    };

    return params;
}

std::vector<RDFTParams> generateCombinedParamsForRDFT() {
    const std::vector<std::vector<RDFTParams>> allTypeParams{
        generateParamsForRDFT<element::Type_t::f32>()
    };

    std::vector<RDFTParams> combinedParams;

    for (const auto& params : allTypeParams) {
        combinedParams.insert(combinedParams.end(), params.begin(), params.end());
    }

    return combinedParams;
}

// Instantiates ReferenceRDFTLayerTest under the prefix smoke_RDFT_With_Hardcoded_Refs, once for
// every case returned by generateCombinedParamsForRDFT(). Each instance is named by
// ReferenceRDFTLayerTest::getTestCaseName.
INSTANTIATE_TEST_SUITE_P(
    smoke_RDFT_With_Hardcoded_Refs,
    ReferenceRDFTLayerTest,
    ::testing::ValuesIn(generateCombinedParamsForRDFT()),
    ReferenceRDFTLayerTest::getTestCaseName);
}  // namespace
